{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<center><img src=\"../images/DLI Header.png\" alt=\"Header\" style=\"width: 400px;\"/></center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Getting Started with AI on Jetson Nano\n",
    "### Interactive Regression Tool"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook is an interactive data collection, training, and testing tool, provided as part of the NVIDIA Deep Learning Institute (DLI) course, \"Getting Started with AI on Jetson Nano\". It is designed to be run on the Jetson Nano in conjunction with the detailed instructions provided in the online DLI course pages. \n",
    "\n",
    "To start the tool, set the **Camera** and **Task** code cell definitions, then execute all cells.  The interactive tool widgets at the bottom of the notebook will display.  The tool can then be used to gather data, add data, train data, and test data in an iterative and interactive fashion! \n",
    "\n",
    "The explanations in this notebook are intentionally minimal to provide a streamlined experience.  Please see the DLI course pages for detailed information on tool operation and project creation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Camera\n",
    "First, create your camera and set it to `running`.  Uncomment the appropriate camera selection lines, depending on which type of camera you're using (USB or CSI). This cell may take several seconds to execute.\n",
    "<div style=\"border:2px solid black; background-color:#e3ffb3; font-size:12px; padding:8px; margin-top: auto;\">\n",
    "    <h4><i>Tip</i></h4>\n",
    "    <p>There can only be one instance of CSICamera or USBCamera at a time.  Before starting this notebook, make sure you have executed the final \"shutdown\" cell in any other notebooks you have run so that the camera is released. \n",
    "    </p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check device number\n",
    "!ls -ltrh /dev/video*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jetcam.usb_camera import USBCamera\n",
    "from jetcam.csi_camera import CSICamera\n",
    "\n",
    "# for USB Camera (Logitech C270 webcam), uncomment the following line\n",
    "camera = USBCamera(width=224, height=224, capture_device=0) # confirm the capture_device number\n",
    "\n",
    "# for CSI Camera (Raspberry Pi Camera Module V2), uncomment the following line\n",
    "# camera = CSICamera(width=224, height=224, capture_device=0) # confirm the capture_device number\n",
    "\n",
    "camera.running = True\n",
    "print(\"camera created\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Task\n",
    "Next, define your project `TASK` and what `CATEGORIES` of data you will collect.  You may optionally define space for multiple `DATASETS` with names of your choosing. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " Uncomment/edit the associated lines for the XY regression task you're building and execute the cell. This cell should only take a few seconds to execute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.transforms as transforms\n",
    "from dataset import XYDataset\n",
    "\n",
    "TASK = 'face'\n",
    "# TASK = 'diy'\n",
    "\n",
    "CATEGORIES = ['nose', 'left_eye', 'right_eye']\n",
    "# CATEGORIES = [ 'diy_1', 'diy_2', 'diy_3']\n",
    "\n",
    "DATASETS = ['A', 'B']\n",
    "# DATASETS = ['A', 'B', 'C']\n",
    "\n",
    "TRANSFORMS = transforms.Compose([\n",
    "    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),\n",
    "    transforms.Resize((224, 224)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "])\n",
    "\n",
    "datasets = {}\n",
    "for name in DATASETS:\n",
    "    datasets[name] = XYDataset('../data/regression/' + TASK + '_' + name, CATEGORIES, TRANSFORMS)\n",
    "    \n",
    "print(\"{} task with {} categories defined\".format(TASK, CATEGORIES))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up the data directory location if not there already\n",
    "DATA_DIR = '/nvdli-nano/data/regression/'\n",
    "!mkdir -p {DATA_DIR}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Collection\n",
    "Execute the cell below to create the data collection tool widget. This cell should only take a few seconds to execute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import ipywidgets\n",
    "import traitlets\n",
    "from IPython.display import display\n",
    "from jetcam.utils import bgr8_to_jpeg\n",
    "from jupyter_clickable_image_widget import ClickableImageWidget\n",
    "\n",
    "\n",
    "# initialize active dataset\n",
    "dataset = datasets[DATASETS[0]]\n",
    "\n",
    "# unobserve all callbacks from camera in case we are running this cell for second time\n",
    "camera.unobserve_all()\n",
    "\n",
    "# create image preview\n",
    "camera_widget = ClickableImageWidget(width=camera.width, height=camera.height)\n",
    "with open(\"../images/ready_img.jpg\", \"rb\") as file:\n",
    "    default_image = file.read()\n",
    "snapshot_widget = ipywidgets.Image(value=default_image, width=camera.width, height=camera.height)\n",
    "traitlets.dlink((camera, 'value'), (camera_widget, 'value'), transform=bgr8_to_jpeg)\n",
    "\n",
    "# create widgets\n",
    "dataset_widget = ipywidgets.Dropdown(options=DATASETS, description='dataset')\n",
    "category_widget = ipywidgets.Dropdown(options=dataset.categories, description='category')\n",
    "count_widget = ipywidgets.IntText(description='count')\n",
    "\n",
    "# manually update counts at initialization\n",
    "count_widget.value = dataset.get_count(category_widget.value)\n",
    "\n",
    "# sets the active dataset\n",
    "def set_dataset(change):\n",
    "    global dataset\n",
    "    dataset = datasets[change['new']]\n",
    "    count_widget.value = dataset.get_count(category_widget.value)\n",
    "dataset_widget.observe(set_dataset, names='value')\n",
    "\n",
    "# update counts when we select a new category\n",
    "def update_counts(change):\n",
    "    count_widget.value = dataset.get_count(change['new'])\n",
    "category_widget.observe(update_counts, names='value')\n",
    "\n",
    "\n",
    "def save_snapshot(_, content, msg):\n",
    "    if content['event'] == 'click':\n",
    "        data = content['eventData']\n",
    "        x = data['offsetX']\n",
    "        y = data['offsetY']\n",
    "        \n",
    "        # save to disk\n",
    "        dataset.save_entry(category_widget.value, camera.value, x, y)\n",
    "        \n",
    "        # display saved snapshot\n",
    "        snapshot = camera.value.copy()\n",
    "        snapshot = cv2.circle(snapshot, (x, y), 8, (0, 255, 0), 3)\n",
    "        snapshot_widget.value = bgr8_to_jpeg(snapshot)\n",
    "        count_widget.value = dataset.get_count(category_widget.value)\n",
    "        \n",
    "camera_widget.on_msg(save_snapshot)\n",
    "\n",
    "data_collection_widget = ipywidgets.VBox([\n",
    "    ipywidgets.HBox([camera_widget, snapshot_widget]),\n",
    "    dataset_widget,\n",
    "    category_widget,\n",
    "    count_widget\n",
    "])\n",
    "\n",
    "# display(data_collection_widget)\n",
    "print(\"data_collection_widget created\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model\n",
    "Execute the following cell to define the neural network and adjust the fully connected layer (`fc`) to match the outputs required for the project.  This cell may take several seconds to execute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "\n",
    "device = torch.device('cuda')\n",
    "output_dim = 2 * len(dataset.categories)  # x, y coordinate for each category\n",
    "\n",
    "# ALEXNET\n",
    "# model = torchvision.models.alexnet(pretrained=True)\n",
    "# model.classifier[-1] = torch.nn.Linear(4096, output_dim)\n",
    "\n",
    "# SQUEEZENET \n",
    "# model = torchvision.models.squeezenet1_1(pretrained=True)\n",
    "# model.classifier[1] = torch.nn.Conv2d(512, output_dim, kernel_size=1)\n",
    "# model.num_classes = len(dataset.categories)\n",
    "\n",
    "# RESNET 18\n",
    "model = torchvision.models.resnet18(pretrained=True)\n",
    "model.fc = torch.nn.Linear(512, output_dim)\n",
    "\n",
    "# RESNET 34\n",
    "# model = torchvision.models.resnet34(pretrained=True)\n",
    "# model.fc = torch.nn.Linear(512, output_dim)\n",
    "\n",
    "model = model.to(device)\n",
    "\n",
    "model_save_button = ipywidgets.Button(description='save model')\n",
    "model_load_button = ipywidgets.Button(description='load model')\n",
    "model_path_widget = ipywidgets.Text(description='model path', value='/nvdli-nano/data/regression/my_xy_model.pth')\n",
    "\n",
    "def load_model(c):\n",
    "    model.load_state_dict(torch.load(model_path_widget.value))\n",
    "model_load_button.on_click(load_model)\n",
    "    \n",
    "def save_model(c):\n",
    "    torch.save(model.state_dict(), model_path_widget.value)\n",
    "model_save_button.on_click(save_model)\n",
    "\n",
    "model_save_button.click()\n",
    "\n",
    "model_widget = ipywidgets.VBox([\n",
    "    model_path_widget,\n",
    "    ipywidgets.HBox([model_load_button, model_save_button])\n",
    "])\n",
    "\n",
    "# display(model_widget)\n",
    "print(\"model configured and model_widget created\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Live  Execution\n",
    "Execute the cell below to set up the live execution widget.  This cell should only take a few seconds to execute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import threading\n",
    "import time\n",
    "from utils import preprocess\n",
    "import torch.nn.functional as F\n",
    "\n",
    "state_widget = ipywidgets.ToggleButtons(options=['stop', 'live'], description='state', value='stop')\n",
    "with open(\"../images/ready_img.jpg\", \"rb\") as file:\n",
    "    default_image = file.read()\n",
    "prediction_widget = ipywidgets.Image(format='jpeg', width=camera.width, height=camera.height, value=default_image)\n",
    "\n",
    "def live(state_widget, model, camera, prediction_widget):\n",
    "    global dataset\n",
    "    while state_widget.value == 'live':\n",
    "        image = camera.value\n",
    "        preprocessed = preprocess(image)\n",
    "        output = model(preprocessed).detach().cpu().numpy().flatten()\n",
    "        category_index = dataset.categories.index(category_widget.value)\n",
    "        x = output[2 * category_index]\n",
    "        y = output[2 * category_index + 1]\n",
    "        \n",
    "        x = int(camera.width * (x / 2.0 + 0.5))\n",
    "        y = int(camera.height * (y / 2.0 + 0.5))\n",
    "        \n",
    "        prediction = image.copy()\n",
    "        prediction = cv2.circle(prediction, (x, y), 8, (255, 0, 0), 3)\n",
    "        prediction_widget.value = bgr8_to_jpeg(prediction)\n",
    "            \n",
    "def start_live(change):\n",
    "    if change['new'] == 'live':\n",
    "        execute_thread = threading.Thread(target=live, args=(state_widget, model, camera, prediction_widget))\n",
    "        execute_thread.start()\n",
    "\n",
    "state_widget.observe(start_live, names='value')\n",
    "\n",
    "live_execution_widget = ipywidgets.VBox([\n",
    "    prediction_widget,\n",
    "    state_widget\n",
    "])\n",
    "\n",
    "# display(live_execution_widget)\n",
    "print(\"live_execution_widget created\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training and Evaluation\n",
    "Execute the following cell to define the trainer, and the widget to control it. This cell may take several seconds to execute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 8\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)\n",
    "\n",
    "epochs_widget = ipywidgets.IntText(description='epochs', value=1)\n",
    "eval_button = ipywidgets.Button(description='evaluate')\n",
    "train_button = ipywidgets.Button(description='train')\n",
    "loss_widget = ipywidgets.FloatText(description='loss')\n",
    "progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='progress')\n",
    "\n",
    "def train_eval(is_training):\n",
    "    global BATCH_SIZE, LEARNING_RATE, MOMENTUM, model, dataset, optimizer, eval_button, train_button, accuracy_widget, loss_widget, progress_widget, state_widget\n",
    "    \n",
    "    try:\n",
    "        train_loader = torch.utils.data.DataLoader(\n",
    "            dataset,\n",
    "            batch_size=BATCH_SIZE,\n",
    "            shuffle=True\n",
    "        )\n",
    "\n",
    "        state_widget.value = 'stop'\n",
    "        train_button.disabled = True\n",
    "        eval_button.disabled = True\n",
    "        time.sleep(1)\n",
    "\n",
    "        if is_training:\n",
    "            model = model.train()\n",
    "        else:\n",
    "            model = model.eval()\n",
    "\n",
    "        while epochs_widget.value > 0:\n",
    "            i = 0\n",
    "            sum_loss = 0.0\n",
    "            error_count = 0.0\n",
    "            for images, category_idx, xy in iter(train_loader):\n",
    "                # send data to device\n",
    "                images = images.to(device)\n",
    "                xy = xy.to(device)\n",
    "\n",
    "                if is_training:\n",
    "                    # zero gradients of parameters\n",
    "                    optimizer.zero_grad()\n",
    "\n",
    "                # execute model to get outputs\n",
    "                outputs = model(images)\n",
    "\n",
    "                # compute MSE loss over x, y coordinates for associated categories\n",
    "                loss = 0.0\n",
    "                for batch_idx, cat_idx in enumerate(list(category_idx.flatten())):\n",
    "                    loss += torch.mean((outputs[batch_idx][2 * cat_idx:2 * cat_idx+2] - xy[batch_idx])**2)\n",
    "                loss /= len(category_idx)\n",
    "\n",
    "                if is_training:\n",
    "                    # run backpropogation to accumulate gradients\n",
    "                    loss.backward()\n",
    "\n",
    "                    # step optimizer to adjust parameters\n",
    "                    optimizer.step()\n",
    "\n",
    "                # increment progress\n",
    "                count = len(category_idx.flatten())\n",
    "                i += count\n",
    "                sum_loss += float(loss)\n",
    "                progress_widget.value = i / len(dataset)\n",
    "                loss_widget.value = sum_loss / i\n",
    "                \n",
    "            if is_training:\n",
    "                epochs_widget.value = epochs_widget.value - 1\n",
    "            else:\n",
    "                break\n",
    "    except e:\n",
    "        pass\n",
    "    model = model.eval()\n",
    "\n",
    "    train_button.disabled = False\n",
    "    eval_button.disabled = False\n",
    "    state_widget.value = 'live'\n",
    "    \n",
    "train_button.on_click(lambda c: train_eval(is_training=True))\n",
    "eval_button.on_click(lambda c: train_eval(is_training=False))\n",
    "    \n",
    "train_eval_widget = ipywidgets.VBox([\n",
    "    epochs_widget,\n",
    "    progress_widget,\n",
    "    loss_widget,\n",
    "    ipywidgets.HBox([train_button, eval_button])\n",
    "])\n",
    "\n",
    "# display(train_eval_widget)\n",
    "print(\"trainer configured and train_eval_widget created\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Display the Interactive Tool!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The interactive tool includes widgets for data collection, training, and testing.  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<center><img src=\"../images/regression_XY_tool_key2.png\" alt=\"tool key\" width=500/></center>\n",
    "<br>\n",
    "<center><img src=\"../images/regression_XY_tool_key1.png\" alt=\"tool key\"/></center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Execute the cell below to create and display the full interactive widget.  Follow the instructions in the online DLI course pages to build your project."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine all the widgets into one display\n",
    "all_widget = ipywidgets.VBox([\n",
    "    ipywidgets.HBox([data_collection_widget, live_execution_widget]), \n",
    "    train_eval_widget,\n",
    "    model_widget\n",
    "])\n",
    "\n",
    "display(all_widget)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1 style=\"background-color:#76b900;\"></h1>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Before you go...<br><br>Shut down the camera and/or notebook kernel to release the camera resource"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attention!  Execute this cell before moving to another notebook\n",
    "# The USB camera application only requires that the notebook be reset\n",
    "# The CSI camera application requires that the 'camera' object be specifically released\n",
    "\n",
    "import os\n",
    "import IPython\n",
    "\n",
    "if type(camera) is CSICamera:\n",
    "    print(\"Ignore 'Exception in thread' tracebacks\\n\")\n",
    "    camera.cap.release()\n",
    "\n",
    "os._exit(00)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Return to the DLI course pages for the next instructions."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<center><img src=\"../images/DLI Header.png\" alt=\"Header\" style=\"width: 400px;\"/></center>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
